#!/usr/bin/python3
###############################################################################
'''文本用户界面'''
# Copyright (c) 2022 Xu Ruijun | 1687701765@qq.com
#
# This file is part of Electronic Analog Filter Design Tool(eAFDTool).
#
# eAFDTool is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or any later version.
#
# eAFDTool is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# eAFDTool. If not, see <https://www.gnu.org/licenses/>.
###############################################################################
import curses
import time

from .expr2str import num_SIprefix
from .parser_str import parser_SInum
from .curr_str import vblocks

# Module metadata.
__author__ = "Xu Ruijun"
__copyright__ = "Copyright (c) 2022 Xu Ruijun"
__license__ = "GPLv3 or later"


def lblocks(x, l):
    """Render a horizontal bar ``l`` eighths of a cell long, ``x`` eighths filled.

    Args:
        x: filled length in eighths of a character cell, ``0 <= x <= l``.
        l: total bar length in eighths; must be a multiple of 8.

    Returns:
        ``l // 8`` glyphs taken from ``vblocks``: ``x // 8`` copies of
        ``vblocks[-1]``, then (unless the bar is completely full) the partial
        glyph ``vblocks[x % 8]`` followed by ``vblocks[0]`` padding.
    """
    assert l % 8 == 0
    assert 0 <= x <= l
    cells = l // 8
    full, part = divmod(x, 8)
    bar = vblocks[-1] * full
    if x != l:
        # One partial cell, then pad out to the full bar width.
        bar += vblocks[part] + vblocks[0] * (cells - full - 1)
    return bar

def draw_table(win, wides, height, Nw=None, Nh=None):
    """Draw a table's inner grid lines into ``win``.

    Only the column separators, the tees on the top/bottom rows and the row
    dividers are drawn.  The window's outer border is not drawn here
    (``Adj.init_win`` draws it with ``border`` before calling this).

    Args:
        win: curses window to draw into.
        wides: list of column widths, or one int repeated ``Nw`` times.
        height: list of row heights, or one int repeated ``Nh`` times.
        Nw: number of columns, used only when ``wides`` is an int.
        Nh: number of rows, used only when ``height`` is an int.
    """
    if isinstance(wides, int):
        wides = [wides] * Nw
    if isinstance(height, int):
        height = [height] * Nh
    rules = ['─' * w for w in wides]
    blank_row = '│'.join(' ' * w for w in wides)
    divider = '├%s┤' % '┼'.join(rules)
    win.addstr(0, 1, '┬'.join(rules))
    row = 1
    last = len(height) - 1
    for idx, h in enumerate(height):
        for _ in range(h):
            win.addstr(row, 1, blank_row)
            row += 1
        # No divider after the final row; the bottom rule follows instead.
        if idx < last:
            win.addstr(row, 0, divider)
            row += 1
    win.addstr(row, 1, '┴'.join(rules))


class Adj:
    """Keyboard-driven panel of labelled horizontal bars, one per row.

    ``init_win`` lays the panel out as a bordered table with three columns:
    the row name, a bar drawn with ``lblocks``, and a ``value_wide``-cell text
    field.  Up/Down select a row, Left/Right call ``adj_event``, and typing a
    digit or '.' starts line entry that is handed to ``input_finish``.
    Subclasses override ``adj_event`` and ``input_finish``.
    """

    # Width, in character cells, of the value column laid out by init_win.
    value_wide = 5

    def __init__(self, win, names, values=None):
        """
        Args:
            win: curses window the panel draws into.
            names: sized sequence of row labels.
            values: initial per-row values; defaults to a list of zeros.
        """
        self.win = win
        self.names = names
        if values is None:
            values = [0] * len(names)
        self.values = values
        self.bar_i = 0  # index of the currently selected row
        # Width of the name column.  Every other column offset is derived from
        # it, so the bars and values line up for labels of any length.
        self.name_wide = max(map(len, names), default=0)

    def _bar_col(self):
        """Return the first column of the bar (border + names + separator)."""
        return self.name_wide + 2

    def _value_col(self):
        """Return the first column of the value field (after bar + separator)."""
        return self.name_wide + self.bar_wide + 3

    def init_win(self):
        """Resize the window to fit all rows and draw border, grid and names.

        Sets ``bar_wide`` to the width left over for the bar column.
        """
        N = len(self.names)
        self.name_wide = max(map(len, self.names), default=0)
        win_wide = self.win.getmaxyx()[1] - self.win.getyx()[1]
        self.win.resize(2*N+1, win_wide)
        self.win.border(0)
        # 4 = left border + two column separators + right border.
        self.bar_wide = win_wide - self.name_wide - self.value_wide - 4
        draw_table(self.win, [self.name_wide, self.bar_wide, self.value_wide],
                   1, Nh=N)
        for i, name in enumerate(self.names):
            self.win.addstr(2*i+1, 1, name)
        self.win.refresh()

    def set_bar(self, bar_i, value):
        """Draw row ``bar_i``'s bar filled to ``value`` eighths of a cell.

        Values outside ``0 .. bar_wide*8`` are ignored (nothing is drawn).
        """
        if 0 <= value <= self.bar_wide*8:
            self.win.addstr(2*bar_i+1, self._bar_col(),
                            lblocks(value, self.bar_wide*8))

    def adj_event(self, bar_i, delta):
        """Add ``delta`` to row ``bar_i``'s value and redraw its bar."""
        # add process in subclass
        self.values[bar_i] += delta
        self.set_bar(bar_i, self.values[bar_i])
        self.win.refresh()

    def input_num(self, bar_i, last):
        """Read a line of text typed into row ``bar_i``'s value field.

        ``last`` is the key code that started the entry; it becomes the first
        character of the string passed to ``input_finish``.
        """
        value_col = self._value_col()
        self.win.addch(2*bar_i+1, value_col, last)
        self.win.refresh()
        self.win.move(2*bar_i+1, value_col+1)
        curses.echo()
        curses.nocbreak()
        try:
            s = self.win.getstr()
            s = chr(last) + s.decode()
            self.input_finish(bar_i, s)
        finally:
            # Always restore the key-at-a-time, no-echo mode used by the rest
            # of the UI, even if reading or handling the entry raises.
            curses.noecho()
            curses.cbreak()

    def input_finish(self, bar_i, s):
        """Handle the text ``s`` entered for row ``bar_i`` (subclass hook)."""
        # add process in subclass
        pass

    def set_str(self, bar_i, s):
        """Write ``s`` into row ``bar_i``'s value field."""
        # Pad to the column width so a shorter string fully overwrites a
        # longer one written earlier.
        self.win.addstr(2*bar_i+1, self._value_col(), s.ljust(self.value_wide))

    def change_cursor_bar(self, bar_i):
        """Park the cursor on row ``bar_i`` between the bar and the value."""
        self.win.move(2*bar_i+1, self._value_col() - 1)
        self.win.refresh()

    def rxkey(self, keycode):
        """Dispatch one key code read from curses."""
        if keycode == curses.KEY_UP:
            if self.bar_i > 0:
                self.bar_i -= 1
                self.change_cursor_bar(self.bar_i)
        elif keycode == curses.KEY_DOWN:
            if self.bar_i < len(self.names) - 1:
                self.bar_i += 1
                self.change_cursor_bar(self.bar_i)
        elif keycode == curses.KEY_LEFT:
            self.adj_event(self.bar_i, -1)
        elif keycode == curses.KEY_RIGHT:
            self.adj_event(self.bar_i, 1)
        elif ord('0') <= keycode <= ord('9') or keycode == ord('.'):
            self.input_num(self.bar_i, keycode)
        elif keycode == ord('\t'):
            self.change_cursor_bar(self.bar_i)


class SK_tAdj(Adj):
    """Bar panel for the solver's ``tname`` parameters.

    Only ``N_keep`` (= ``solver.Nord - 1``) values are treated as independent.
    The most recently edited bars are tracked oldest-first in ``near_change``.
    Their values are passed to ``solver.t_adj`` as keyword constraints, and it
    returns the full value list.  Bars are scaled relative to the largest value.
    """

    def __init__(self, win, parent):
        names = parent.solver.tname
        super().__init__(win, names)
        self.parent = parent
        self.values = parent.solver.initv()
        self.maxv = None  # largest value; set by update_values, scales the bars
        self.N_keep = parent.solver.Nord - 1
        # Indices of the bars changed most recently, oldest first.
        self.near_change = list(range(self.N_keep))

    def update_values(self):
        """Redraw every bar and value string from ``self.values``."""
        self.maxv = max(self.values)
        for i, v in enumerate(self.values):
            # Negative values draw an empty bar; a non-positive maximum would
            # otherwise divide by zero.
            if v < 0 or self.maxv <= 0:
                block_v = 0
            else:
                block_v = int(v/self.maxv*self.bar_wide*8)
            self.set_bar(i, block_v)
            self.set_str(i, num_SIprefix(v))
        self.change_cursor_bar(self.bar_i)

    def change_event(self, bar_i, value):
        """Request ``value`` for bar ``bar_i`` and re-solve the other values.

        The change takes effect only if ``solver.t_adj`` accepts it.  Otherwise
        the error is printed and the previous, unchanged state is redrawn.
        """
        bi2kw = self.parent.solver.tchr
        # Stage the change on copies so a rejected value leaves no trace.
        near_change = list(self.near_change)
        if bar_i in near_change:
            near_change.remove(bar_i)
        else:
            near_change.pop(0)  # forget the oldest constraint
        near_change.append(bar_i)
        candidate = list(self.values)
        candidate[bar_i] = value
        d = {bi2kw[i]: candidate[i] for i in near_change}
        assert len(d) == self.N_keep
        try:
            new_values = self.parent.solver.t_adj(**d)
        except Exception as e:
            # NOTE(review): print() writes over the curses screen; a status
            # line would be a better place for this message.
            print(e)
            self.update_values()
        else:
            self.near_change = near_change
            self.values = new_values
            self.parent.update_all()
        self.change_cursor_bar(self.bar_i)

    def adj_event(self, bar_i, delta):
        """Nudge bar ``bar_i`` by ``delta`` eighths of a cell at the current scale."""
        v = self.values[bar_i] + delta*self.maxv/(self.bar_wide*8)
        self.change_event(bar_i, v)

    def input_finish(self, bar_i, s):
        """Parse the typed text ``s`` (SI-prefixed number) and apply it."""
        try:
            v = parser_SInum(s)
        except Exception as e:
            # Unparseable entry: report it and redraw the previous values
            # instead of letting the exception end the curses session.
            print(e)
            self.update_values()
            return
        self.change_event(bar_i, v)


class SK_rcAdj(Adj):
    """Bar panel for the solver's ``rcn`` components.

    Editing a bar "locks" it: ``solver.calc_rc`` is called with that single
    value as a keyword argument and returns the full value list.  Bars whose
    index is in ``solver.rx`` are scaled to the largest of those values.  All
    other bars are scaled to the largest value indexed by ``solver.cx``.
    """

    def __init__(self, win, parent):
        names = [s.upper().center(4) for s in parent.solver.rcn]
        super().__init__(win, names)
        self.lock_bar = 0  # index of the bar whose value is given to calc_rc
        self.parent = parent
        self.values = parent.solver.calc_rc(r1=10e3)

    def _in_rx(self, i):
        """Return True if bar ``i`` is scaled against the ``solver.rx`` maximum."""
        # Same index set that update_values uses to compute maxr.
        return i in self.parent.solver.rx

    def update_values(self):
        """Redraw every bar and value string from ``self.values``."""
        solver = self.parent.solver
        self.maxr = max(self.values[i] for i in solver.rx)
        self.maxc = max(self.values[i] for i in solver.cx)
        for i, v in enumerate(self.values):
            maxv = self.maxr if self._in_rx(i) else self.maxc
            # Negative values draw an empty bar; a non-positive maximum would
            # otherwise divide by zero.
            if v < 0 or maxv <= 0:
                block_v = 0
            else:
                block_v = int(v/maxv*self.bar_wide*8)
            self.set_bar(i, block_v)
            self.set_str(i, num_SIprefix(v))

    def change_event(self, bar_i, value):
        """Lock bar ``bar_i`` at ``value`` and recompute the other values.

        If ``calc_rc`` rejects the value, the previous value and lock are
        restored and the panel is redrawn.
        """
        old_value, old_lock = self.values[bar_i], self.lock_bar
        self.values[bar_i] = value
        self.lock_bar = bar_i
        if not self.re_calc():
            self.values[bar_i] = old_value
            self.lock_bar = old_lock
            self.update_values()
        self.change_cursor_bar(bar_i)

    def re_calc(self):
        """Recompute all values from the locked bar via ``solver.calc_rc``.

        Returns:
            True if the values were recomputed and redrawn.  False if
            ``calc_rc`` raised; ``self.values`` is then left as it was.
        """
        bi2kw = self.parent.solver.rcn
        kwargs = {bi2kw[self.lock_bar]: self.values[self.lock_bar]}
        try:
            self.values = self.parent.solver.calc_rc(**kwargs)
        except Exception:
            # Best effort: the solver may reject some inputs; keep old values
            # and let the caller decide how to react.
            return False
        self.update_values()
        return True

    def adj_event(self, bar_i, delta):
        """Nudge bar ``bar_i`` by ``delta`` eighths of a cell at its scale."""
        maxv = self.maxr if self._in_rx(bar_i) else self.maxc
        v = self.values[bar_i] + delta*maxv/(self.bar_wide*8)
        self.change_event(bar_i, v)

    def input_finish(self, bar_i, s):
        """Parse the typed text ``s`` (SI-prefixed number) and apply it."""
        try:
            v = parser_SInum(s)
        except Exception:
            # Unparseable entry: redraw the previous values instead of letting
            # the exception end the curses session.
            self.update_values()
            self.change_cursor_bar(bar_i)
            return
        self.change_event(bar_i, v)


class SK_Adjs:
    """Top-level TUI: a ``tname`` panel above an ``rcn`` component panel."""

    # Preferred first row of the R/C panel inside the parent window.
    rc_top = 15

    def __init__(self, solver):
        self.solver = solver

    def run(self, win):
        """Build both panels inside ``win`` and dispatch keys until 'q'.

        Tab switches the active panel; every other key goes to the active one.
        """
        Nt = len(self.solver.tname)
        Nrc = len(self.solver.rcn)
        t_height = Nt*2 + 1
        twin = win.derwin(t_height, 50, 0, 0)
        # Push the R/C panel below the t panel when the t panel is taller than
        # rc_top rows, so the two sub-windows never overlap.
        rcwin = win.derwin(Nrc*2+1, 50, max(self.rc_top, t_height), 0)
        self.t_adj = SK_tAdj(twin, self)
        self.rc_adj = SK_rcAdj(rcwin, self)
        self.subadj = [self.t_adj, self.rc_adj]
        for adj in self.subadj:
            adj.init_win()
            adj.update_values()
        curr = 0
        while True:
            keycode = win.getch()
            if keycode == ord('q'):
                break
            if keycode == ord('\t'):
                curr = (curr + 1) % len(self.subadj)
            # Tab is forwarded too, so the newly active panel re-places its cursor.
            self.subadj[curr].rxkey(keycode)

    def update_all(self):
        """Redraw the t panel, then recompute and refresh the R/C panel."""
        self.t_adj.update_values()
        self.rc_adj.re_calc()
        self.rc_adj.win.refresh()
